NNX: fix Linen-parity gaps + unit tests #4255
Merged
NuojCheng reviewed Jun 26, 2026
NuojCheng approved these changes Jun 26, 2026
igorts-git approved these changes Jun 26, 2026
With pure_nnx/enable_nnx/pure_nnx_decoder defaulting to True, several train/loss/decoder/metrics/GRPO paths diverged from Linen.

Fixes:
- skip_step_on_spikes: forward loss/grad_norm through apply_gradients to the optax skip-step optimizer; read is_skipped back off the NNX optimizer.
- loss_fn: check the indexer dense-warmup before num_vocab_tiling (Linen order).
- decoder logits guards: use the model_mode call-arg, not self.model_mode.
- routed_bias read: dispatch the Linen intermediates path vs an NNX suffix match.
- record_activation_metrics: collect by path suffix so it works for Linen and NNX, scanned and unscanned (also fixes a pre-existing Linen KeyError).
- nnx_attrs_to_linen_vars: skip non-Variable attrs (qwix bookkeeping) not raise.
- config: error when qwix quant can't reach a bridged Linen decoder under pure_nnx.
- maxengine.set_engine_vars_from_base_engine: skip the quant copy and use the NNX kv-cache annotations on the NNX path.
- GRPO _train_step_nnx: gradient-accumulation scan loop; fix the GA loss metric.
- GRPO pathways reshard: drop the scan_layers=False NotImplementedError.
- GRPO host-offload: move optimizer state to device before the in-place update.

Tests: train_nnx_test, grpo_nnx_test, maxengine_nnx_test, nnx_quant_guard_test.
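The "collect by path suffix" idea in the commit message can be sketched in plain Python. This is only an illustration under assumptions, not the code from this PR: the helper names (`flatten`, `collect_by_suffix`), the nested-dict layouts, and the metric key are all hypothetical. The point is that matching on the trailing path components finds the same leaf whether or not a Linen-style "intermediates" prefix is present.

```python
from typing import Any, Dict, Iterator, Tuple

Path = Tuple[str, ...]


def flatten(tree: Dict[str, Any], prefix: Path = ()) -> Iterator[Tuple[Path, Any]]:
    """Yield (path, leaf) pairs from a nested dict."""
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            yield from flatten(value, path)
        else:
            yield path, value


def collect_by_suffix(tree: Dict[str, Any], suffix: Path) -> Dict[Path, Any]:
    """Return every leaf whose path ends with `suffix` (hypothetical helper)."""
    if not suffix:
        raise ValueError("suffix must contain at least one path component")
    n = len(suffix)
    return {path: leaf for path, leaf in flatten(tree) if path[-n:] == suffix}


# Hypothetical Linen-style layout: leaves sit under an "intermediates" prefix.
linen_like = {
    "intermediates": {
        "decoder": {"layers_0": {"mlp": {"activation_fraction_zero": 0.25}}}
    }
}
# Hypothetical NNX-style layout: same trailing components, no such prefix.
nnx_like = {"decoder": {"layers": {"mlp": {"activation_fraction_zero": 0.5}}}}

suffix = ("mlp", "activation_fraction_zero")
linen_hits = collect_by_suffix(linen_like, suffix)
nnx_hits = collect_by_suffix(nnx_like, suffix)
```

A lookup keyed on the full Linen path would miss the second tree entirely, which is the kind of silent drop or KeyError the commit message describes.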
Description
NNX (`pure_nnx=True`) had Linen-only or silently divergent gaps across train / loss / decoder / metrics / GRPO. This PR closes them and adds correctness unit tests. The fixes apply on `main` independently; PR #3526 (flip defaults to NNX) makes them the live default path, and the unit tests pin the behavior either way.

Fixes
- `skip_step_on_spikes`: silent no-op on NNX (`apply_gradients` didn't forward `loss`/`grad_norm`). Now forwarded, and the metric is surfaced.
- `loss_fn`: NNX checked vocab-tiling before indexer warm-up; reordered to match Linen.
- Decoder logits guards: read `self.model_mode` instead of the call-arg `model_mode`. Now use the call-arg.
- `routed_bias`: updates silently dropped on NNX (the Linen `"intermediates"` prefix is absent on the NNX dict). Now matched by suffix.
- `record_internal_nn_metrics`: KeyError on NNX. Now NNX-aware via suffix collection.
- qwix quantization under `pure_nnx` with the bridged decoder: the bridge now skips qwix's non-`Variable` attrs, and a config guard rejects bridged-decoder + qwix.
- `maxengine.set_engine_vars_from_base_engine`: AttributeError on NNX; now uses `get_kv_cache_annotations_nnx`.
- GRPO `gradient_accumulation_steps>1`: NotImplementedError on NNX. Implemented; also fixed the GA loss metric (`sum/GA`, not `sum/total_weights`).
- GRPO pathways reshard with `scan_layers=False`: NotImplementedError on NNX. Guard removed (the NNX policy already matches the inference layout).
- GRPO `optimizer_memory_host_offload`: ignored on NNX; now moves the optimizer state to device before the update.

Also re-declared the legacy GRPO config fields (`inference_replicas` / `inference_devices_per_replica` / `inference_rollouts` / `use_pathways_reshard`) in `types.py`. They were dropped from the schema, so `grpo.yml` couldn't load (pre-existing on `main`).

Tests
`tests/unit/{train_nnx_test,grpo_nnx_test,maxengine_nnx_test,nnx_quant_guard_test}.py`: 27 pass on CPU.
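The GA loss-metric fix (sum/GA rather than sum/total_weights) can be illustrated with a small arithmetic sketch. It assumes that each microbatch loss is already a per-microbatch mean, which is an assumption about the training step rather than something stated here; the function names and numbers below are hypothetical.

```python
from typing import Sequence


def ga_loss_metric(microbatch_mean_losses: Sequence[float]) -> float:
    """Average already-normalized microbatch losses over the GA steps."""
    ga_steps = len(microbatch_mean_losses)
    return sum(microbatch_mean_losses) / ga_steps


def ga_loss_metric_divided_by_weights(
    microbatch_mean_losses: Sequence[float], microbatch_weights: Sequence[float]
) -> float:
    """Divide the summed mean losses by the total token weight.

    Under the assumption above this normalizes twice, so the reported
    value shrinks with the number of tokens.
    """
    return sum(microbatch_mean_losses) / sum(microbatch_weights)


# Two accumulation steps with hypothetical mean losses and token weights.
losses = [2.0, 4.0]
weights = [100.0, 300.0]

reported = ga_loss_metric(losses)
double_normalized = ga_loss_metric_divided_by_weights(losses, weights)
```

With these numbers the first metric stays on the scale of a per-token loss, while the second is two orders of magnitude smaller. That would distort the logged metric, and nothing in this sketch touches the gradients.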